# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting code to AST.

Adapted from Tangent.
"""

import ast
import inspect
import io
import linecache
import re
import sys
import textwrap
import tokenize

import astunparse
import gast
import six

from nvidia.dali._autograph.pyct import errors
from nvidia.dali._autograph.pyct import inspect_utils


# Python 2 required a `from __future__` preamble for some features; on
# Python 3 no preamble is needed, so both constants are effectively empty.
PY2_PREAMBLE = textwrap.dedent(
    """
"""
)
PY3_PREAMBLE = ""
# Placeholder; overwritten below with the platform's maximum integer. Used as
# an "infinity" sentinel when computing line spans in _parse_lambda.
MAX_SIZE = 0

if sys.version_info >= (3, 9):
    # The stdlib `ast` module gained `unparse` in Python 3.9; rebind the
    # module-level `astunparse` name to it so `unparse` below prefers the
    # standard library over the third-party astunparse package.
    astunparse = ast  # noqa: F811

if sys.version_info >= (3,):
    STANDARD_PREAMBLE = PY3_PREAMBLE
    MAX_SIZE = sys.maxsize
else:
    STANDARD_PREAMBLE = PY2_PREAMBLE
    MAX_SIZE = sys.maxint

# Number of `__future__` import lines in the standard preamble; used by
# parse_expression as the number of leading AST nodes to drop.
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count("__future__")


# Matches the leading whitespace of a line.
_LEADING_WHITESPACE = re.compile(r"\s*")


def _unfold_continuations(code_string):
    """Removes any backslash line continuations from the code."""
    return code_string.replace("\\\n", "")


def dedent_block(code_string):
    """Dedents a code block so that its first line starts at column zero.

    Backslash line continuations are stripped first, then the indentation of
    the first logical line is removed from every line of the block. Code that
    already starts at column zero is returned unchanged.

    Args:
      code_string: Text, the source code of a block (e.g. as extracted by
        inspect), possibly indented.

    Returns:
      Text, the dedented source code.

    Raises:
      errors.UnsupportedLanguageElementError: if the block mixes tabs and
        spaces for indentation.
    """
    # Remove backslash line continuations; they confuse the tokenizer's
    # notion of logical lines.
    code_string = code_string.replace("\\\n", "")

    token_gen = tokenize.generate_tokens(io.StringIO(code_string).readline)

    block_indentation = None
    tokens = []
    try:
        for tok in token_gen:
            tokens.append(tok)
    except tokenize.TokenError:
        # Resolution of lambda functions may yield incomplete code, which can
        # in turn generate this error. We silently ignore this error because the
        # parser may still be able to deal with it.
        pass

    for tok in tokens:
        tok_type, tok_string, _, _, _ = tok
        if tok_type == tokenize.INDENT:
            # The first INDENT token carries the block's base indentation.
            block_indentation = tok_string
            break
        elif tok_type not in (tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
            # A substantive token before any INDENT means the code already
            # starts at column zero.
            block_indentation = ""
            break

    if not block_indentation:
        return code_string

    block_level = len(block_indentation)
    first_indent_uses_tabs = "\t" in block_indentation
    for i, tok in enumerate(tokens):
        tok_type, tok_string, _, _, _ = tok
        if tok_type == tokenize.INDENT:
            if (" " in tok_string and first_indent_uses_tabs) or (
                "\t" in tok_string and not first_indent_uses_tabs
            ):
                # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
                # See:
                # https://docs.python.org/3/reference/lexical_analysis.html#indentation
                raise errors.UnsupportedLanguageElementError(
                    "code mixing tabs and spaces for indentation is not allowed"
                )
            if len(tok_string) >= block_level:
                tok_string = tok_string[block_level:]
            # A (type, string) 2-tuple is sufficient for untokenize; it switches
            # untokenize into compatibility mode.
            tokens[i] = (tok_type, tok_string)

    new_code = tokenize.untokenize(tokens)

    # Note: untokenize respects the line structure, but not the whitespace within
    # lines. For example, `def foo()` may be untokenized as `def foo ()`
    # So instead of using the output of dedent, we match the leading whitespace
    # on each line.
    dedented_code = []
    for line, new_line in zip(code_string.split("\n"), new_code.split("\n")):
        original_indent = re.match(r"\s*", line).group()
        new_indent = re.match(r"\s*", new_line).group()
        if len(original_indent) > len(new_indent):
            dedented_line = line[len(original_indent) - len(new_indent) :]
        else:
            dedented_line = line
        dedented_code.append(dedented_line)
    return "\n".join(dedented_code)


def parse_entity(entity, future_features):
    """Returns the AST and source code of given entity.

    Args:
      entity: Any, Python function/method/class
      future_features: Iterable[Text], future features to use (e.g.
        'print_statement'). See
        https://docs.python.org/2/reference/simple_stmts.html#future

    Returns:
      gast.AST, Text: the parsed AST node; the source code that was parsed to
      generate the AST (including any prefixes that this function may have added).
    """
    # Lambdas need special handling: their source cannot be isolated by line.
    if inspect_utils.islambda(entity):
        return _parse_lambda(entity)

    try:
        raw_source = inspect_utils.getimmediatesource(entity)
    except OSError as e:
        raise errors.InaccessibleSourceCodeError(
            f"Unable to locate the source code of {entity}. Note that functions"
            " defined in certain environments, like the interactive Python shell,"
            " do not expose their source code. If that is the case, you should"
            " define them in a .py source file. If you are certain the code is"
            " graph-compatible, wrap the call in the do_not_convert decorator."
            f" Original error: {e}"
        )

    # Prepend one `from __future__ import ...` line per requested feature so
    # the parser honors them; parse() then drops these preamble nodes.
    source_parts = ["from __future__ import {}".format(name) for name in future_features]
    source_parts.append(dedent_block(raw_source))
    source = "\n".join(source_parts)

    return parse(source, preamble_len=len(future_features)), source


def _without_context(node, lines, minl, maxl):
    """Returns a clean node and source code without indenting and context."""
    # Shift all line numbers so the extracted snippet starts at line zero.
    for inner in gast.walk(node):
        if getattr(inner, "lineno", None) is not None:
            inner.lineno -= minl
        if getattr(inner, "end_lineno", None) is not None:
            inner.end_lineno -= minl

    code_lines = lines[minl - 1 : maxl]

    # Attempt to clean up surrounding context code.

    end_col_offset = getattr(node, "end_col_offset", None)
    if end_col_offset is not None:
        # This is only available in 3.8.
        code_lines[-1] = code_lines[-1][:end_col_offset]

    col_offset = getattr(node, "col_offset", None)
    if col_offset is None:
        # Older Python: try to find the "lambda" token. This is brittle.
        match = re.search(r"(?<!\w)lambda(?!\w)", code_lines[0])
        if match is not None:
            col_offset = match.start(0)

    if col_offset is not None:
        code_lines[0] = code_lines[0][col_offset:]

    code_block = "\n".join(c.rstrip() for c in code_lines)

    return node, code_block


def _arg_name(node):
    if node is None:
        return None
    if isinstance(node, gast.Name):
        return node.id
    assert isinstance(node, str)
    return node


def _node_matches_argspec(node, func):
    """Returns True if node fits the argspec of func."""
    # TODO(mdan): Use just inspect once support for Python 2 is dropped.
    spec = inspect.getfullargspec(func)
    args = node.args

    # Positional args, *args, **kwargs and keyword-only args must all match.
    return (
        tuple(_arg_name(a) for a in args.args) == tuple(spec.args)
        and _arg_name(args.vararg) == spec.varargs
        and _arg_name(args.kwarg) == spec.varkw
        and tuple(_arg_name(a) for a in args.kwonlyargs) == tuple(spec.kwonlyargs)
    )


def _parse_lambda(lam):
    """Returns the AST and source code of given lambda function.

    Args:
      lam: types.LambdaType, Python function/method/class

    Returns:
      gast.AST, Text: the parsed AST node; the source code that was parsed to
      generate the AST (including any prefixes that this function may have added).

    Raises:
      errors.UnsupportedLanguageElementError: if no lambda AST node spanning the
        function's definition line could be found, or if multiple matching
        nodes with identical signatures were found.
    """
    # TODO(mdan): Use a fast path if the definition is not multi-line.
    # We could detect that the lambda is in a multi-line expression by looking
    # at the surrounding code - an surrounding set of parentheses indicates a
    # potential multi-line definition.

    mod = inspect.getmodule(lam)
    f = inspect.getsourcefile(lam)
    def_line = lam.__code__.co_firstlineno

    # This method is more robust that just calling inspect.getsource(mod), as it
    # works in interactive shells, where getsource would fail. This is the
    # same procedure followed by inspect for non-modules:
    # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
    lines = linecache.getlines(f, mod.__dict__)
    source = "".join(lines)

    # Narrow down to the last node starting before our definition node.
    all_nodes = parse(source, preamble_len=0, single_node=False)
    search_nodes = []
    for node in all_nodes:
        # Also include nodes without a line number, for safety. This is defensive -
        # we don't know whether such nodes might exist, and if they do, whether
        # they are not safe to skip.
        # TODO(mdan): Replace this check with an assertion or skip such nodes.
        if getattr(node, "lineno", def_line) <= def_line:
            search_nodes.append(node)
        else:
            # Found a node starting past our lambda - can stop the search.
            break

    # Extract all lambda nodes from the shortlist.
    lambda_nodes = []
    for node in search_nodes:
        lambda_nodes.extend(n for n in gast.walk(node) if isinstance(n, gast.Lambda))

    # Filter down to lambda nodes which span our actual lambda.
    candidates = []
    for ln in lambda_nodes:
        # Compute the min/max line numbers covered by this lambda node and
        # everything nested inside it; MAX_SIZE acts as an "infinity" sentinel.
        minl, maxl = MAX_SIZE, 0
        for n in gast.walk(ln):
            minl = min(minl, getattr(n, "lineno", minl))
            lineno = getattr(n, "lineno", maxl)
            end_lineno = getattr(n, "end_lineno", None)
            if end_lineno is not None:
                # end_lineno is more precise, but lineno should almost always work too.
                lineno = end_lineno
            maxl = max(maxl, lineno)
        if minl <= def_line <= maxl:
            candidates.append((ln, minl, maxl))

    # Happy path: exactly one node found.
    if len(candidates) == 1:
        ((node, minl, maxl),) = candidates  # pylint:disable=unbalanced-tuple-unpacking
        return _without_context(node, lines, minl, maxl)

    elif not candidates:
        lambda_codes = "\n".join([unparse(l) for l in lambda_nodes])
        raise errors.UnsupportedLanguageElementError(
            f"could not parse the source code of {lam}:"
            f" no matching AST found among candidates:\n{lambda_codes}"
        )

    # Attempt to narrow down selection by signature is multiple nodes are found.
    matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
    if len(matches) == 1:
        ((node, minl, maxl),) = matches
        return _without_context(node, lines, minl, maxl)

    # Give up if could not narrow down to a single node.
    matches = "\n".join(
        "Match {}:\n{}\n".format(i, unparse(node, include_encoding_marker=False))
        for i, (node, _, _) in enumerate(matches)
    )
    raise errors.UnsupportedLanguageElementError(
        f"could not parse the source code of {lam}: found multiple definitions"
        " with identical signatures at the location. This error"
        " may be avoided by defining each lambda on a single line and with"
        f" unique argument names. The matching definitions were:\n{matches}"
    )


# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
    """Returns the AST of given piece of code.

    Args:
      src: Text
      preamble_len: Int, indicates leading nodes in the parsed AST which should be
        dropped.
      single_node: Bool, whether `src` is assumed to be represented by exactly one
        AST node.

    Returns:
      ast.AST
    """
    module_node = gast.parse(src)
    nodes = module_node.body
    if preamble_len:
        nodes = nodes[preamble_len:]
    if single_node:
        if len(nodes) != 1:
            raise ValueError("expected exactly one node, got {}".format(nodes))
        return nodes[0]
    return nodes


def parse_expression(src):
    """Returns the AST of given identifier.

    Args:
      src: A piece of code that represents a single Python expression
    Returns:
      A gast.AST object.
    Raises:
      ValueError: if src does not consist of a single Expression.
    """
    src = STANDARD_PREAMBLE + src.strip()
    node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
    # Validate unconditionally: this check was previously wrapped in
    # `if __debug__:`, which silently disabled it when running under `-O`.
    if not isinstance(node, gast.Expr):
        raise ValueError("expected exactly one node of type Expr, got {}".format(node))
    return node.value


def unparse(node, indentation=None, include_encoding_marker=True):
    """Returns the source code of given AST.

    Args:
      node: The code to compile, as an AST object, or a list/tuple of AST
        objects. gast nodes are converted to standard ast nodes first.
      indentation: Unused, deprecated. The returning code will always be indented
        at 4 spaces.
      include_encoding_marker: Bool, whether to include a comment on the first
        line to explicitly specify UTF-8 encoding.

    Returns:
      Text, the source code generated from the AST object(s), joined with
      newlines.
    """
    del indentation  # astunparse doesn't allow configuring it.
    if not isinstance(node, (list, tuple)):
        node = (node,)

    codes = []
    if include_encoding_marker:
        codes.append("# coding=utf-8")
    for n in node:
        if isinstance(n, gast.AST):
            # The unparser operates on standard ast nodes, not gast ones.
            ast_n = gast.gast_to_ast(n)
        else:
            ast_n = n

        # On Python 3.9+ the module-level `astunparse` name is rebound to the
        # stdlib `ast` module, whose unparse requires populated locations.
        if astunparse is ast:
            ast.fix_missing_locations(ast_n)  # Only ast needs to call this.
        codes.append(astunparse.unparse(ast_n).strip())

    return "\n".join(codes)
